from torch.autograd import Variable
import torch
import torch.nn.functional as F


def FGSM(model, x, y, optimizer, args):
	"""Craft a one-step L-inf adversarial batch and return the training loss on it.

	The attack starts from a uniform random point inside the ``args.eps`` ball
	around ``x``, takes a single signed-gradient step of size ``args.eps`` on
	the cross-entropy loss, then projects back into the ball and into
	``[0, 1]``.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode while the attack runs and left in train mode on return.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		optimizer: only ``zero_grad()`` is called on it.
		args: namespace providing ``eps`` (perturbation radius).

	Returns:
		``(logits, loss)`` computed on the adversarial batch. ``loss`` keeps
		its autograd graph so the caller can run ``backward()``.
	"""
	model.eval()
	epsilon = args.eps
	x_nat = x.detach()
	# Sample the random start directly on x's device and dtype. Building it on
	# the CPU and calling .cuda() breaks CPU-only runs.
	x_adv = x_nat + torch.empty_like(x_nat).uniform_(-epsilon, epsilon)

	x_adv.requires_grad_()
	with torch.enable_grad():
		attack_loss = F.cross_entropy(model(x_adv), y)
	grad = torch.autograd.grad(attack_loss, [x_adv])[0]
	x_adv = x_adv.detach() + epsilon * torch.sign(grad.detach())
	# Project onto the eps-ball around the clean input, then onto the valid range.
	x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
	x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

	model.train()
	# Reset parameter gradients before the training forward pass.
	model.zero_grad()
	optimizer.zero_grad()
	logits = model(x_adv)
	loss = F.cross_entropy(logits, y)
	return logits, loss

def PGD(model, x, y, optimizer, args):
	"""Craft a multi-step L-inf adversarial batch and return the training loss on it.

	Starting from a uniform random point inside the ``args.eps`` ball around
	``x``, the attack performs ``args.ns`` signed-gradient ascent steps of size
	``args.ss`` on the cross-entropy loss. After every step the iterate is
	projected back into the ball and into ``[0, 1]``.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode while the attack runs and left in train mode on return.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		optimizer: only ``zero_grad()`` is called on it.
		args: namespace providing ``eps`` (radius), ``ns`` (number of steps)
			and ``ss`` (step size).

	Returns:
		``(logits, loss)`` computed on the adversarial batch. ``loss`` keeps
		its autograd graph so the caller can run ``backward()``.
	"""
	model.eval()
	epsilon = args.eps
	num_steps = args.ns
	step_size = args.ss
	x_nat = x.detach()
	# Sample the random start directly on x's device and dtype. Building it on
	# the CPU and calling .cuda() breaks CPU-only runs.
	x_adv = x_nat + torch.empty_like(x_nat).uniform_(-epsilon, epsilon)

	for _ in range(num_steps):
		x_adv.requires_grad_()
		with torch.enable_grad():
			attack_loss = F.cross_entropy(model(x_adv), y)
		grad = torch.autograd.grad(attack_loss, [x_adv])[0]
		x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
		# Project onto the eps-ball around the clean input, then onto the valid range.
		x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
		x_adv = torch.clamp(x_adv, 0.0, 1.0)

	model.train()
	x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
	# Reset parameter gradients before the training forward pass.
	model.zero_grad()
	optimizer.zero_grad()
	logits = model(x_adv)
	loss = F.cross_entropy(logits, y)
	return logits, loss

def TRADES(model, x, y, optimizer, args):
	"""Return adversarial logits and a clean-CE + beta * KL training loss.

	The adversarial batch is found by ``args.ns`` signed-gradient ascent steps
	of size ``args.ss`` on the summed KL divergence between the model's
	prediction on the perturbed input and its prediction on the clean input.
	The start point is the clean input plus small Gaussian noise, and each
	step is projected into the ``args.eps`` L-inf ball and into ``[0, 1]``.

	The returned loss is ``cross_entropy(model(x), y)`` plus ``args.beta``
	times the batch-averaged KL divergence between adversarial and clean
	predictions.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode while the attack runs and left in train mode on return.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		optimizer: only ``zero_grad()`` is called on it.
		args: namespace providing ``eps``, ``ns``, ``ss`` and ``beta``.

	Returns:
		``(adv_logits, loss)``; ``loss`` keeps its autograd graph.
	"""
	model.eval()
	epsilon = args.eps
	num_steps = args.ns
	step_size = args.ss
	beta = args.beta
	x_nat = x.detach()
	x_adv = x_nat + 0.001 * torch.randn_like(x_nat)
	# The clean prediction is a fixed target for the attack: compute it once
	# and without a graph, so no activations are kept alive across the loop.
	with torch.no_grad():
		nat_prob = F.softmax(model(x_nat), dim=1)
	for _ in range(num_steps):
		x_adv.requires_grad_()
		with torch.enable_grad():
			logits_adv = model(x_adv)
			loss_kl = F.kl_div(F.log_softmax(logits_adv, dim=1), nat_prob, reduction='sum')

		grad = torch.autograd.grad(loss_kl, [x_adv])[0]
		x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
		# Project onto the eps-ball around the clean input, then onto the valid range.
		x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
		x_adv = torch.clamp(x_adv, 0.0, 1.0)
	model.train()
	x_adv = x_adv.detach()
	optimizer.zero_grad()
	# calculate robust loss
	logits = model(x)
	adv_logits = model(x_adv)
	loss_natural = F.cross_entropy(logits, y)
	# The 1e-20 offsets are kept exactly as in the original formulation.
	loss_robust = F.kl_div(F.log_softmax(adv_logits + 1e-20, dim=1) + 1e-20,
						   F.softmax(logits, dim=1) + 1e-20, reduction='sum') / x.shape[0]
	loss = loss_natural + beta * loss_robust
	return adv_logits, loss
class PGD_TE():
	"""PGD adversarial-training loss with a running-average soft-label regulariser.

	A per-sample buffer ``soft_labels`` holds an exponential moving average of
	the model's softmax output on clean inputs. From epoch ``es`` onwards both
	the attack objective and the training loss add
	``weight * mean((softmax(logits) - normalised_soft_labels) ** 2)`` to the
	cross-entropy. Before that epoch the loss is plain cross-entropy.
	"""

	def __init__(self, num_samples=50000, num_classes=10, momentum=0.9, es=90, step_size=0.003, epsilon=0.031,
				 perturb_steps=10, norm='linf', device=None):
		"""Allocate the soft-label buffer and store the attack settings.

		Args:
			num_samples: number of rows in the soft-label buffer.
			num_classes: number of columns in the soft-label buffer.
			momentum: weight of the previous value in the moving average.
			es: first epoch at which the soft-label term is used.
			step_size: attack step size.
			epsilon: attack radius.
			perturb_steps: number of attack iterations.
			norm: ``'linf'`` or ``'l2'``.
				NOTE(review): any other value gets the Gaussian start but no
				update step, so it yields an almost-clean batch.
			device: where to allocate the buffer. ``None`` selects CUDA when
				available, otherwise the CPU.
		"""
		# soft labels start as all-zero rows; __call__ fills them in once epoch >= es
		print('number samples: ', num_samples, 'num_classes: ', num_classes)
		if device is None:
			device = 'cuda' if torch.cuda.is_available() else 'cpu'
		self.soft_labels = torch.zeros(num_samples, num_classes, dtype=torch.float, device=device)
		self.momentum = momentum
		self.es = es
		self.step_size = step_size
		self.epsilon = epsilon
		self.perturb_steps = perturb_steps
		self.norm = norm

	def _loss(self, logits, y, soft_labels_batch, weight):
		"""Cross-entropy, plus the weighted soft-label term when a target batch is given."""
		loss = F.cross_entropy(logits, y)
		if soft_labels_batch is not None:
			loss = loss + weight * ((F.softmax(logits, dim=1) - soft_labels_batch) ** 2).mean()
		return loss

	def __call__(self, x_natural, y, index, epoch, model, optimizer, weight):
		"""Generate an adversarial batch and return ``(logits, loss)`` on it.

		Args:
			x_natural: clean input batch. The ``[0, 1]`` clamp assumes this range.
			y: targets handed to ``F.cross_entropy``.
			index: positions of this batch's samples in ``soft_labels``.
			epoch: current epoch, compared against ``es``.
			model: network. Eval mode during the attack, train mode on return.
			optimizer: only ``zero_grad()`` is called on it.
			weight: multiplier of the soft-label term.

		Returns:
			``(logits, loss)`` for the adversarial batch; ``loss`` keeps its graph.
		"""
		model.eval()
		x_nat = x_natural.detach()

		soft_labels_batch = None
		if epoch >= self.es:
			# Only detached probabilities are needed, so run the clean forward
			# pass without a graph, and only when it is actually used.
			with torch.no_grad():
				prob = F.softmax(model(x_nat), dim=1)
			if self.soft_labels.device != prob.device:
				self.soft_labels = self.soft_labels.to(prob.device)
			self.soft_labels[index] = self.momentum * self.soft_labels[index] + (1 - self.momentum) * prob
			soft_labels_batch = self.soft_labels[index] / self.soft_labels[index].sum(1, keepdim=True)

		# generate adversarial example
		if self.norm == 'linf':
			x_adv = x_nat + torch.empty_like(x_nat).uniform_(-self.epsilon, self.epsilon)
		else:
			x_adv = x_nat + 0.001 * torch.randn_like(x_nat)
		for _ in range(self.perturb_steps):
			x_adv.requires_grad_()
			with torch.enable_grad():
				loss = self._loss(model(x_adv), y, soft_labels_batch, weight)
			grad = torch.autograd.grad(loss, [x_adv])[0].detach()
			x_adv = x_adv.detach()
			if self.norm == 'linf':
				x_adv = x_adv + self.step_size * torch.sign(grad)
				x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon), x_nat + self.epsilon)
			elif self.norm == 'l2':
				# Per-sample gradient norm, reshaped to broadcast over inputs of any rank.
				g_norm = grad.reshape(grad.shape[0], -1).norm(dim=1).view(-1, *([1] * (grad.dim() - 1)))
				scaled_grad = grad / (g_norm + 1e-10)
				delta = (x_adv + self.step_size * scaled_grad - x_nat).reshape(x_nat.size(0), -1)
				# renorm caps each sample's perturbation at an L2 norm of epsilon.
				x_adv = x_nat + delta.renorm(p=2, dim=0, maxnorm=self.epsilon).view_as(x_nat)
			x_adv = torch.clamp(x_adv, 0.0, 1.0)

		# compute loss
		model.train()
		optimizer.zero_grad()
		x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

		# calculate robust loss
		logits = model(x_adv)
		loss = self._loss(logits, y, soft_labels_batch, weight)
		return logits, loss

def PGD_new(model, x, y, optimizer, epoch, args):
	"""Return the training loss on PGD examples, or on clean inputs before ``args.pgd``.

	For ``epoch >= args.pgd`` the batch is attacked with ``args.ns``
	signed-gradient steps of size ``args.ss`` inside the ``args.eps`` L-inf
	ball, starting from a uniform random point in that ball. For earlier
	epochs the clean batch is used. In both cases the batch is clamped to
	``[0, 1]`` before the training forward pass.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode first and left in train mode on return.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		optimizer: only ``zero_grad()`` is called on it.
		epoch: current epoch, compared against ``args.pgd``.
		args: namespace providing ``eps``, ``ns``, ``ss`` and ``pgd``.

	Returns:
		``(logits, loss)``; ``loss`` keeps its autograd graph.
	"""
	model.eval()
	epsilon = args.eps
	num_steps = args.ns
	step_size = args.ss
	if epoch >= args.pgd:
		x_nat = x.detach()
		# Sample the random start directly on x's device and dtype. Building
		# it on the CPU and calling .cuda() breaks CPU-only runs.
		x_adv = x_nat + torch.empty_like(x_nat).uniform_(-epsilon, epsilon)

		for _ in range(num_steps):
			x_adv.requires_grad_()
			with torch.enable_grad():
				attack_loss = F.cross_entropy(model(x_adv), y)
			grad = torch.autograd.grad(attack_loss, [x_adv])[0]
			x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
			# Project onto the eps-ball around the clean input, then onto the valid range.
			x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
			x_adv = torch.clamp(x_adv, 0.0, 1.0)
	else:
		x_adv = x

	model.train()
	x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
	# zero gradient
	optimizer.zero_grad()
	logits = model(x_adv)
	loss = F.cross_entropy(logits, y)
	return logits, loss

class PGD_TE_new():
	"""Soft-label-regularised training loss with a delayed start of the PGD attack.

	A per-sample buffer ``soft_labels`` holds an exponential moving average of
	the model's softmax output on clean inputs. From epoch ``es`` onwards the
	loss adds ``weight * mean((softmax(logits) - normalised_soft_labels) ** 2)``
	to the cross-entropy. Adversarial examples are only generated from epoch
	``pgd`` onwards; earlier epochs train on the clean batch.
	"""

	def __init__(self, num_samples=50000, num_classes=10, momentum=0.9, es=90, step_size=0.003, epsilon=0.031,
				 perturb_steps=10, norm='linf', pgd=100, device=None):
		"""Allocate the soft-label buffer and store the attack settings.

		Args:
			num_samples: number of rows in the soft-label buffer.
			num_classes: number of columns in the soft-label buffer.
			momentum: weight of the previous value in the moving average.
			es: first epoch at which the soft-label term is used.
			step_size: attack step size.
			epsilon: attack radius.
			perturb_steps: number of attack iterations.
			norm: ``'linf'`` or ``'l2'``.
				NOTE(review): any other value gets the Gaussian start but no
				update step, so it yields an almost-clean batch.
			pgd: first epoch at which the attack is run.
			device: where to allocate the buffer. ``None`` selects CUDA when
				available, otherwise the CPU.
		"""
		# soft labels start as all-zero rows; __call__ fills them in once epoch >= es
		print('number samples: ', num_samples, 'num_classes: ', num_classes)
		if device is None:
			device = 'cuda' if torch.cuda.is_available() else 'cpu'
		self.soft_labels = torch.zeros(num_samples, num_classes, dtype=torch.float, device=device)
		self.momentum = momentum
		self.es = es
		self.step_size = step_size
		self.epsilon = epsilon
		self.perturb_steps = perturb_steps
		self.norm = norm
		self.pgd = pgd

	def _loss(self, logits, y, soft_labels_batch, weight):
		"""Cross-entropy, plus the weighted soft-label term when a target batch is given."""
		loss = F.cross_entropy(logits, y)
		if soft_labels_batch is not None:
			loss = loss + weight * ((F.softmax(logits, dim=1) - soft_labels_batch) ** 2).mean()
		return loss

	def __call__(self, x_natural, y, index, epoch, model, optimizer, weight):
		"""Return ``(logits, loss)`` on the adversarial batch, or the clean one before ``pgd``.

		Args:
			x_natural: clean input batch. The ``[0, 1]`` clamp applied to
				adversarial examples assumes this range.
			y: targets handed to ``F.cross_entropy``.
			index: positions of this batch's samples in ``soft_labels``.
			epoch: current epoch, compared against ``es`` and ``pgd``.
			model: network. Eval mode during the attack, train mode on return.
			optimizer: only ``zero_grad()`` is called on it.
			weight: multiplier of the soft-label term.

		Returns:
			``(logits, loss)``; ``loss`` keeps its autograd graph.
		"""
		model.eval()

		soft_labels_batch = None
		if epoch >= self.es:
			# Only detached probabilities are needed, so run the clean forward
			# pass without a graph, and only when it is actually used.
			with torch.no_grad():
				prob = F.softmax(model(x_natural.detach()), dim=1)
			if self.soft_labels.device != prob.device:
				self.soft_labels = self.soft_labels.to(prob.device)
			self.soft_labels[index] = self.momentum * self.soft_labels[index] + (1 - self.momentum) * prob
			soft_labels_batch = self.soft_labels[index] / self.soft_labels[index].sum(1, keepdim=True)

		if epoch >= self.pgd:
			# generate adversarial example
			x_nat = x_natural.detach()
			if self.norm == 'linf':
				x_adv = x_nat + torch.empty_like(x_nat).uniform_(-self.epsilon, self.epsilon)
			else:
				x_adv = x_nat + 0.001 * torch.randn_like(x_nat)
			for _ in range(self.perturb_steps):
				x_adv.requires_grad_()
				with torch.enable_grad():
					loss = self._loss(model(x_adv), y, soft_labels_batch, weight)
				grad = torch.autograd.grad(loss, [x_adv])[0].detach()
				x_adv = x_adv.detach()
				if self.norm == 'linf':
					x_adv = x_adv + self.step_size * torch.sign(grad)
					x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon), x_nat + self.epsilon)
				elif self.norm == 'l2':
					# Per-sample gradient norm, reshaped to broadcast over inputs of any rank.
					g_norm = grad.reshape(grad.shape[0], -1).norm(dim=1).view(-1, *([1] * (grad.dim() - 1)))
					scaled_grad = grad / (g_norm + 1e-10)
					delta = (x_adv + self.step_size * scaled_grad - x_nat).reshape(x_nat.size(0), -1)
					# renorm caps each sample's perturbation at an L2 norm of epsilon.
					x_adv = x_nat + delta.renorm(p=2, dim=0, maxnorm=self.epsilon).view_as(x_nat)
				x_adv = torch.clamp(x_adv, 0.0, 1.0)
			x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
		else:
			# Before the attack is switched on, train on the clean batch as given.
			x_adv = x_natural
		# compute loss
		model.train()
		optimizer.zero_grad()
		# calculate robust loss
		logits = model(x_adv)
		loss = self._loss(logits, y, soft_labels_batch, weight)
		return logits, loss

def FGSM_overfitting(model, x, y, optimizer, args):
	"""One-step signed-gradient attack from the clean input, then the training loss.

	Unlike a randomly initialised attack, the step starts exactly at ``x``.
	The perturbed batch is ``x + args.eps * sign(grad)``, projected into the
	``args.eps`` L-inf ball around ``x`` and clamped to ``[0, 1]``.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode for the attack and left in train mode on return.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		optimizer: only ``zero_grad()`` is called on it.
		args: namespace providing ``eps``.

	Returns:
		``(logits, loss)`` on the perturbed batch; ``loss`` keeps its graph.
	"""
	radius = args.eps
	model.eval()

	# Gradient of the cross-entropy with respect to the (detached) clean input.
	probe = x.detach()
	probe.requires_grad_()
	with torch.enable_grad():
		attack_logits = model(probe)
		attack_loss = F.cross_entropy(attack_logits, y)
	input_grad = torch.autograd.grad(attack_loss, [probe])[0]

	# Signed step, then projection onto the ball and the valid value range.
	stepped = probe.detach() + radius * torch.sign(input_grad.detach())
	lower = x - radius
	upper = x + radius
	stepped = torch.min(torch.max(stepped, lower), upper)
	stepped = torch.clamp(stepped, 0.0, 1.0)

	model.train()
	adversarial = Variable(torch.clamp(stepped, 0.0, 1.0), requires_grad=False)
	# zero gradient
	model.zero_grad()
	optimizer.zero_grad()
	outputs = model(adversarial)
	return outputs, F.cross_entropy(outputs, y)


def pgd(model, x, y, epsilon=8./255., num_steps=20, step_size=2./255.):
	"""Return L-inf PGD adversarial examples for ``x``.

	Starting from a uniform random point inside the ``epsilon`` ball around
	``x``, runs ``num_steps`` signed-gradient ascent steps of size
	``step_size`` on the cross-entropy loss. Each step is projected back into
	the ball and into ``[0, 1]``.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode and is NOT switched back.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		epsilon: perturbation radius (default 8/255).
		num_steps: number of attack iterations (default 20).
		step_size: per-step size (default 2/255).

	Returns:
		The adversarial batch, detached, with the same shape as ``x``.
	"""
	model.eval()
	x_nat = x.detach()
	# Sample the random start directly on x's device and dtype. Building it on
	# the CPU and calling .cuda() breaks CPU-only runs.
	x_adv = x_nat + torch.empty_like(x_nat).uniform_(-epsilon, epsilon)

	for _ in range(num_steps):
		x_adv.requires_grad_()
		with torch.enable_grad():
			loss = F.cross_entropy(model(x_adv), y)
		grad = torch.autograd.grad(loss, [x_adv])[0]
		x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
		# Project onto the eps-ball around the clean input, then onto the valid range.
		x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
		x_adv = torch.clamp(x_adv, 0.0, 1.0)

	# Final clamp also covers the num_steps == 0 case.
	return torch.clamp(x_adv, 0.0, 1.0).detach()

def fgsm(model, x, y, epsilon=8./255.):
	"""Return one-step L-inf adversarial examples for ``x``.

	Starting from a uniform random point inside the ``epsilon`` ball around
	``x``, takes one signed-gradient step of size ``epsilon`` on the
	cross-entropy loss, then projects into the ball and into ``[0, 1]``.

	Args:
		model: network called as ``model(inputs)``. It is switched to eval
			mode and is NOT switched back.
		x: clean input batch. The ``[0, 1]`` clamp assumes this value range.
		y: targets handed to ``F.cross_entropy``.
		epsilon: perturbation radius and step size (default 8/255).

	Returns:
		The adversarial batch, detached, with the same shape as ``x``.
	"""
	model.eval()
	x_nat = x.detach()
	# Sample the random start directly on x's device and dtype. Building it on
	# the CPU and calling .cuda() breaks CPU-only runs.
	x_adv = x_nat + torch.empty_like(x_nat).uniform_(-epsilon, epsilon)

	x_adv.requires_grad_()
	with torch.enable_grad():
		loss = F.cross_entropy(model(x_adv), y)
	grad = torch.autograd.grad(loss, [x_adv])[0]
	x_adv = x_adv.detach() + epsilon * torch.sign(grad.detach())
	# Project onto the eps-ball around the clean input, then onto the valid range.
	x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
	return torch.clamp(x_adv, 0.0, 1.0).detach()

